fix(MemoryFormatOpsPass): preserve input dim_order for clone/to_copy with no memory_format kwarg#17611

Merged
GregoryComer merged 10 commits into pytorch:main from nefainl:fix/16032-memory-format-ops-pass-preserve-format
Mar 5, 2026

Conversation

nefainl (Contributor) commented Feb 21, 2026

Summary

Fixes #16032

This PR fixes MemoryFormatOpsPass to correctly handle torch.preserve_format semantics for clone() and _to_copy.default operations.

Root cause: When clone() or _to_copy is called without an explicit memory_format kwarg, the pass was defaulting to torch.contiguous_format, causing the output dim_order to be [0,1,2,3] (contiguous) even when the input was channels-last [0,2,3,1]. This caused runtime assertion failures:

```
Code=18 InvalidArgument: tensors_have_same_dim_order(self, out)
```
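
For context, a tensor's dim_order() is a permutation describing its physical layout. A quick illustration (plain PyTorch, independent of this PR) of the two layouts involved:

```python
import torch

x = torch.randn(1, 3, 8, 8)                     # NCHW, contiguous
print(x.dim_order())                            # (0, 1, 2, 3)

x_cl = x.to(memory_format=torch.channels_last)  # NHWC storage
print(x_cl.dim_order())                         # (0, 2, 3, 1)

# clone() without a memory_format kwarg should keep the input layout,
# so cloning x_cl must report (0, 2, 3, 1), not (0, 1, 2, 3).
print(x_cl.clone().dim_order())                 # (0, 2, 3, 1)
```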

Fix: Change the default from torch.contiguous_format to torch.preserve_format, and derive dim_order from the input tensor's dim_order() when preserve_format is used.

This is a minimal, focused fix following the guidance from @GregoryComer in the discussion on PR #17463.
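
A minimal sketch of the new default-resolution logic. The function name and parameters here are illustrative, not the pass's actual code; get_dim_order is the existing helper mentioned later in this thread, and its import path below is an assumption:

```python
import torch

# Assumed import path for the existing helper; adjust if it lives elsewhere.
from executorch.exir.dim_order_utils import get_dim_order

def resolve_dim_order(memory_format, input_tensor, ndim):
    """Sketch of the pass's new default resolution (hypothetical helper)."""
    if memory_format is None:
        # New default: preserve_format instead of contiguous_format.
        memory_format = torch.preserve_format

    if memory_format == torch.preserve_format:
        if input_tensor is not None:
            # Derive the output dim_order from the input tensor's layout.
            return list(input_tensor.dim_order())
        # No single input tensor available (e.g. empty()): conservatively
        # fall back to contiguous dim_order.
        return list(range(ndim))

    # Explicit formats (contiguous_format, channels_last, ...) keep the
    # existing path.
    return get_dim_order(memory_format, ndim)
```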

Changes

  • exir/passes/memory_format_ops_pass.py (+29/-5 lines):

    • Default memory_format to torch.preserve_format instead of torch.contiguous_format
    • When preserve_format is used, derive dim_order from input_tensor.dim_order()
    • Fall back to contiguous if no input tensor is available (e.g., empty())
  • exir/tests/test_passes.py (+130 lines):

    • test_clone_no_kwarg_preserves_channels_last_dim_order: core repro case for #16032
    • test_clone_contiguous_format_kwarg_stays_contiguous: regression guard
    • test_to_copy_no_kwarg_preserves_channels_last_dim_order: verifies the _to_copy.default path

Standalone Reproduction

```python
import torch
from torch.export import export
from executorch.exir import to_edge, EdgeCompileConfig

class ConvClone(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x).clone()

model = ConvClone().to(memory_format=torch.channels_last)
x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)

exported = export(model, (x,))
edge = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=False))

# Before fix: clone node has dim_order=(0,1,2,3) - BUG
# After fix: clone node has dim_order=(0,2,3,1) - CORRECT
for node in edge.exported_program().graph_module.graph.nodes:
    if "_clone_dim_order" in str(node.target):
        print(f"clone dim_order: {tuple(node.meta['val'].dim_order())}")
```

Test Plan

  • All 3 new tests pass
  • Verified fix with standalone reproduction script
  • No changes to existing tests required

Related

  • Fixes #16032
  • Supersedes #17463 (this is the minimal fix extracted from that PR per reviewer feedback)

fix(MemoryFormatOpsPass): preserve input dim_order for clone/to_copy with no memory_format kwarg

Issue pytorch#16032: clone() and _to_copy operations with no explicit memory_format
kwarg were defaulting to contiguous dim_order, causing runtime assertion
failures when cloning channels-last tensors.

Changes:
- Default memory_format to torch.preserve_format instead of torch.contiguous_format
- When preserve_format, derive dim_order from input tensor's dim_order()
- Simplify type annotation: dim_order is always assigned, no Optional needed

Tests:
- test_clone_no_kwarg_preserves_channels_last_dim_order: core repro case
- test_clone_contiguous_format_kwarg_stays_contiguous: regression guard
- test_to_copy_no_kwarg_preserves_channels_last_dim_order: _to_copy path
pytorch-bot Bot commented Feb 21, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17611

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 4f91bc4 with merge base 9f2f005:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label Feb 21, 2026
nefainl (Contributor, Author) commented Feb 21, 2026

@pytorchbot label "release notes: exir"

pytorch-bot Bot added the release notes: exir label Feb 21, 2026
nefainl (Contributor, Author) commented Feb 21, 2026

Request to the reviewers: could you please add @GregoryComer as a reviewer as well? I received very helpful advice from him on the other PR, and it makes sense to have him here too so he can edit if needed. Note that some of the other changes will still be handled in a separate PR (where it also makes sense to have him review).

GregoryComer (Member) left a comment

@nefainl Thanks for the update. This PR looks good; we can merge with a few small changes.

In addition to the few small nits, can you update the tests to also run the model and verify that the output tensor has the correct memory format / dim order?

Inline review threads (outdated): exir/passes/memory_format_ops_pass.py, exir/tests/test_passes.py (two threads)
nefainl (Contributor, Author) commented Feb 26, 2026

> @nefainl Thanks for the update. This PR looks good; we can merge with a few small changes.
>
> In addition to the few small nits, can you update the tests to also run the model and verify that the output tensor has the correct memory format / dim order?

@GregoryComer Thank you for looking into it! I will implement your suggestions, review the lint runner output tomorrow, and make the required fixes.

nefainl (Contributor, Author) commented Feb 26, 2026

Correction: I will look in detail at all the failing checks tomorrow to see how to resolve them.

GregoryComer (Member) commented

> Correction: I will look in detail at all the failing checks tomorrow to see how to resolve them.

Thanks. You can ignore the Samsung failures. The other failures (except lint) look like flakes. I'll re-try the jobs.

nefainl (Contributor, Author) commented Feb 27, 2026

Quick update on this PR:

  • Updated MemoryFormatOpsPass so that:
    • memory_format defaults to torch.preserve_format instead of torch.contiguous_format.
    • For preserve-format (None or torch.preserve_format), dim_order is now derived from the input tensor’s dim_order() when available.
    • When there is no single input tensor (e.g. immutable list inputs for ops like torch.stack), we conservatively fall back to contiguous dim_order = list(range(ndim)); see the snippet after this list.
  • Kept the existing behavior for explicit memory_format values (contiguous, channels_last, etc.), which still go through get_dim_order(mem_format, ndim).
  • Addressed the style nits by removing explicit GitHub issue references from comments/docstrings and keeping them in the commit/PR description instead.
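
A small illustration of why the fallback exists, using the torch.stack case mentioned above (plain PyTorch, not the pass's code):

```python
import torch

# torch.stack consumes a list of tensors, so there is no single input
# tensor whose dim_order() the pass could preserve.
ts = [torch.randn(3, 8, 8) for _ in range(2)]
out = torch.stack(ts)            # shape (2, 3, 8, 8)

# Hence the conservative fallback: contiguous dim_order = list(range(ndim)).
print(list(range(out.dim())))    # [0, 1, 2, 3]
```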

Locally, the new TestMemoryFormatOpsPassPreserveFormat tests pass and confirm:

  • clone() with no memory_format kwarg preserves channels-last dim_order in the exported graph.
  • clone(memory_format=torch.contiguous_format) produces a contiguous layout.
  • to(dtype=...) without memory_format preserves channels-last dim_order.

Let me know if you’d like any additional end-to-end checks (e.g. running the exported graph and inspecting runtime dim_order) beyond what’s already covered in the tests.

NefAI added 4 commits February 27, 2026 21:12

Clarify preserve_format behavior and extend MemoryFormatOpsPass tests to run the models and assert that output tensors have the expected memory format / dim order.

Replace the generator-based dim_order construction with a list comprehension to satisfy flake8 C400 and add the missing blank line before the new test class to align with PEP 8 spacing.
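
For reference, a representative example of the C400 change (not the exact diff): flake8's C400 flags an unnecessary generator passed to list(), which is rewritten as a list comprehension:

```python
ndim = 4

# Before (triggers flake8 C400: unnecessary generator):
dim_order = list(i for i in range(ndim))

# After:
dim_order = [i for i in range(ndim)]
```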
nefainl (Contributor, Author) commented Feb 27, 2026

> > Correction: I will look in detail at all the failing checks tomorrow to see how to resolve them.
>
> Thanks. You can ignore the Samsung failures. The other failures (except lint) look like flakes. I'll re-try the jobs.

Thanks a lot for the input and the helpful review comments, nits, etc. I think I have covered them all now. If you see any additional items, please let me know!

nefainl (Contributor, Author) commented Feb 27, 2026

Each of the three TestMemoryFormatOpsPassPreserveFormat tests now:

  1. Runs the model eagerly with the appropriate input (channels-last or contiguous as needed).
  2. Asserts the output tensor’s memory format via is_contiguous() / is_contiguous(memory_format=torch.channels_last) so we verify the actual tensor layout, not only the exported graph.
  3. Then exports and checks that the graph’s dim_order metadata matches (e.g. (0, 2, 3, 1) for channels-last, (0, 1, 2, 3) for contiguous).

So we’re already checking both runtime output layout and export metadata. If you’d like additional end-to-end checks (e.g. running the exported ExecuTorch program and inspecting runtime dim_order), we can add those.
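
A condensed sketch of that three-step pattern, reusing the ConvClone module from the reproduction above (assertions simplified relative to the actual unittest code):

```python
import torch
from torch.export import export
from executorch.exir import to_edge, EdgeCompileConfig

model = ConvClone().to(memory_format=torch.channels_last)  # from the repro above
x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)

# 1. Run the model eagerly and check the actual output layout.
out = model(x)
assert out.is_contiguous(memory_format=torch.channels_last)

# 2. Export and lower to edge with dim_order enabled.
edge = to_edge(export(model, (x,)),
               compile_config=EdgeCompileConfig(_skip_dim_order=False))

# 3. Check that the graph's dim_order metadata matches the layout.
for node in edge.exported_program().graph_module.graph.nodes:
    if "_clone_dim_order" in str(node.target):
        assert tuple(node.meta["val"].dim_order()) == (0, 2, 3, 1)
```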

nefainl requested a review from GregoryComer February 27, 2026 20:59
GregoryComer (Member) commented

Thanks - looks good. I'm going to run a few more tests and I'll merge if everything looks good.

meta-codesync Bot (Contributor) commented Mar 5, 2026

@GregoryComer has imported this pull request. If you are a Meta employee, you can view this in D95409924.

nefainl (Contributor, Author) commented Mar 5, 2026

> Thanks - looks good. I'm going to run a few more tests and I'll merge if everything looks good.

Great, thanks for your help earlier with the solution direction! I will look into the other PR in more detail later.

GregoryComer merged commit 4c20ef1 into pytorch:main Mar 5, 2026
154 of 157 checks passed
jpiat pushed a commit to jpiat/executorch that referenced this pull request Mar 17, 2026
fix(MemoryFormatOpsPass): preserve input dim_order for clone/to_copy with no memory_format kwarg (pytorch#17611)

Labels

CLA Signed · release notes: exir


Development

Successfully merging this pull request may close these issues.

Dim Order Validation Inconsistency for Edge / Ambiguous Cases

3 participants